'''
Date: 26th Oct 2021
Script to extract host level features from PCAPs of Tor traffic (malware & benign)
@Author: Priyanka G. Dodia
Qatar Computing Research Institute
(!)Pls cite the paper if you wish to integrate/re-use this code.
'''
import os, sys
from datetime import date, datetime
import statistics as stats
import collections

# Helpers
# Get total seconds in given timestamps
def calc_seconds(ts_lst, default_seconds=360):
    '''Return the time span, in seconds, covered by a list of timestamps.

    ts_lst: datetime objects, expected oldest-first; only the first and the
        last entry are read.
    default_seconds: window returned when no usable span can be measured.

    Returns (last - first) in seconds. When the list holds fewer than two
    timestamps, or when the measured span is exactly zero (all timestamps
    identical), default_seconds is returned instead, so that a caller
    dividing a count by the result never divides by zero.
    '''
    if len(ts_lst) >= 2:
        span = (ts_lst[-1] - ts_lst[0]).total_seconds()
        # Only an exactly-zero span is replaced; any other value is passed
        # through untouched.
        if span != 0:
            return span
    return default_seconds

# tconns: dictionary: key: srcip_dstip_sport_dport, val: list of records, each [srcip, dstip, sport, dport, dur, sentd, recvd, spkt, rpkt, connstate, ts]
def _zeek_number(raw, cast):
    '''Convert one log field with `cast`; the unset marker "-" becomes 0.'''
    return 0 if raw == "-" else cast(raw)


def _mean_median_mode(values):
    '''Return [mean, median, mode] of `values`, or [0, 0, 0] when it is empty.'''
    if not values:
        return [0, 0, 0]
    return [stats.mean(values), stats.median(values), stats.mode(values)]


def _is_standard_tor_port(port):
    '''True for destination ports treated as standard: 9000-9010 plus a fixed list.'''
    return 9000 <= port <= 9010 or port in (9020, 9030, 9050, 9150, 443, 80, 8080)


def connlogs(tconns, dtype):
    '''Build the 31 connection-level features for one host.

    tconns: dict mapping a connection key to a list of records. Each record
        is indexable with at least 11 string fields:
        [srcip, dstip, sport, dport, dur, sentd, recvd, spkt, rpkt,
        connstate, ts]. A field holding "-" is treated as unset.
    dtype: string appended to "nonstd_ports_" to name the file that receives
        the non-standard destination ports.

    Returns a list of 31 numbers, in this order:
        1      number of keys in tconns
        2      failed attempts (connstate S0 or REJ)
        3-4    connections per second, failed attempts per second
        5-6    distinct destination ports, most used destination port
        7-8    distinct non-standard ports, most frequent non-standard port
        9-11   mean / smallest / largest duration
        12     connections lasting at most 60 seconds
        13     average gap between consecutive connections
        14-31  mean, median, mode of: total packets, total bytes, sent bytes,
               sent packets, received packets, received bytes

    Feature groups with no data (no ports, no records) are filled with zeros
    so the vector always has 31 entries.

    Side effects: prints debugging output and appends every distinct
    non-standard port (one per line) to the file "nonstd_ports_" + dtype.

    Raises ValueError when a numeric field is neither a number nor "-".
    '''
    # Local import: the file header only brings in date and datetime.
    from datetime import timezone

    print(tconns)
    connfts = [len(tconns)]  # 1

    failed_ts = []
    all_ts = []
    dst_ports = []
    durations = []
    totalpkts = []
    totalspkts = []
    totalrpkts = []
    totaldata = []
    totalrecv = []
    totalsent = []

    for records in tconns.values():
        for rec in records:
            # The destination port is parsed on its own so that an unset
            # source port cannot make the conversion fail.
            dport = -1 if rec[3] == "-" else int(rec[3])
            dur = _zeek_number(rec[4], float)
            sentd = _zeek_number(rec[5], float)
            recvd = _zeek_number(rec[6], float)
            spkt = _zeek_number(rec[7], int)
            rpkt = _zeek_number(rec[8], int)
            # UTC-aware datetimes keep time differences exact even when the
            # local clock changes (DST) inside the capture.
            seen_at = datetime.fromtimestamp(float(rec[10]), tz=timezone.utc)

            # Failed/rejected attempts: conn_state S0 or REJ.
            if rec[9] in ("S0", "REJ"):
                failed_ts.append(seen_at)
            all_ts.append(seen_at)

            if dport != -1:
                dst_ports.append(dport)
            durations.append(dur)
            totalpkts.append(spkt + rpkt)
            totalspkts.append(spkt)
            totalrpkts.append(rpkt)
            totaldata.append(sentd + recvd)
            totalrecv.append(recvd)
            totalsent.append(sentd)

    failed = len(failed_ts)
    connfts.append(failed)  # 2

    # calc_seconds supplies the observation window; a zero window would make
    # these divisions raise ZeroDivisionError.
    all_ts.sort()
    connfts.append(len(all_ts) / calc_seconds(all_ts) if all_ts else 0)  # 3
    failed_ts.sort()
    connfts.append(failed / calc_seconds(failed_ts) if failed_ts else 0)  # 4

    if dst_ports:
        freq = collections.Counter(dst_ports)
        print(dst_ports, set(dst_ports))
        print(freq)
        connfts.append(len(freq))  # 5
        connfts.append(freq.most_common(1)[0][0])  # 6

        # Keep the real hit count of every non-standard port so that the
        # "most frequent" feature is decided by frequency, not by which port
        # happened to be seen first.
        nonstd_hits = collections.Counter()
        for port, hits in freq.items():
            if not _is_standard_tor_port(port):
                print("Nonstd:", port)
                nonstd_hits[port] = hits
        connfts.append(len(nonstd_hits))  # 7

        if nonstd_hits:
            # The context manager closes the handle; the file is only
            # created when there is something to record.
            with open("nonstd_ports_" + dtype, "a+") as port_log:
                port_log.writelines(str(port) + "\n" for port in nonstd_hits)
            print(nonstd_hits)
            connfts.append(nonstd_hits.most_common(1)[0][0])  # 8
        else:
            connfts.append(0)
    else:
        # No record carried a destination port: zero-fill features 5-8.
        connfts += [0, 0, 0, 0]

    if durations:
        ordered = sorted(durations)
        connfts += [stats.mean(durations), ordered[0], ordered[-1]]  # 9-11
    else:
        connfts += [0, 0, 0]
    connfts.append(sum(1 for dur in durations if dur <= 60))  # 12

    gap_total = 0
    for earlier, later in zip(all_ts, all_ts[1:]):
        print(later, earlier)
        gap_total += (later - earlier).total_seconds()
    # The divisor is the number of connections, not the number of gaps; the
    # original author switched to this on purpose, so it is kept.
    connfts.append(gap_total / len(all_ts) if len(all_ts) > 1 else 0)  # 13

    connfts += _mean_median_mode(totalpkts)   # 14-16
    connfts += _mean_median_mode(totaldata)   # 17-19
    connfts += _mean_median_mode(totalsent)   # 20-22
    connfts += _mean_median_mode(totalspkts)  # 23-25
    connfts += _mean_median_mode(totalrpkts)  # 26-28
    connfts += _mean_median_mode(totalrecv)   # 29-31

    print(connfts, len(connfts))
    assert len(connfts) == 31
    return connfts


'''
fields:
ts(0) uid(1) id.orig_h(2) id.orig_p(3) id.resp_h(4) id.resp_p(5) proto(6) trans_id(7) rtt(8) query(9)
qclass(10) qclass_name(11) qtype(12) qtype_name(13) rcode(14) rcode_name(15) AA(16) TC(17) RD(18)
RA(19) Z(20) answers(21) TTLs(22) rejected(23)
indices (within each tconns record, as read by dnslogs): 11-17
'''
def dnslogs(tconns):
    '''Build the 6 DNS-level features for one host.

    tconns: dict mapping a connection key to a list of records. Only records
        with at least 18 fields are inspected; from those the function reads
        index 12 (query), index 14 (response code name) and index 16
        (rejected flag, "T" when set).

    Returns [nxdomain count, refused count, servfail count,
             .onion queries, distinct .onion queries, refused .onion queries].

    A record whose response code name is "-" contributes to none of the three
    response-code counters, but its query is still checked for ".onion".
    '''
    nxdomains = 0
    refused = 0
    servfail = 0
    leakedonions = []
    rej_onions = []

    for records in tconns.values():
        for rec in records:
            if len(rec) < 18:
                continue  # record carries no DNS fields
            query = rec[12]
            rname = rec[14].lower()
            rejected = rec[16] == "T"
            is_onion = ".onion" in query

            if rname != "-":
                # Zeek spells the code "NXDOMAIN"; the plural form this check
                # used previously is still accepted.
                if rname in ("nxdomain", "nxdomains"):
                    nxdomains += 1
                elif rname == "refused" or rejected:
                    refused += 1
                    if is_onion:
                        rej_onions.append(query)
                elif rname == "servfail":
                    servfail += 1
            if is_onion:
                leakedonions.append(query)

    return [
        nxdomains,                # 32
        refused,                  # 33
        servfail,                 # 34
        len(leakedonions),        # 35
        len(set(leakedonions)),   # 36
        len(rej_onions),          # 37
    ]

# indices: 18-19
def httplogs(tconns):
    '''Build the 3 HTTP-level features for one host.

    tconns: dict mapping a connection key to a list of records. Only records
        with at least 20 fields are inspected; index 18 is read as the host
        and index 19 as the URI.

    Returns [records mentioning ".onion", records mentioning "consensus",
             records mentioning "tor"], where a record counts once per
    feature if either its host or its URI contains the keyword.
    '''
    onions = 0
    consensus = 0
    tor = 0
    for records in tconns.values():
        for rec in records:
            if len(rec) < 20:
                continue  # record carries no HTTP fields
            host = rec[18]
            uri = rec[19]
            if ".onion" in host or ".onion" in uri:
                onions += 1
            if "consensus" in host or "consensus" in uri:
                consensus += 1
            # The keyword was previously written "\tor", which Python reads
            # as a TAB followed by "or" and therefore never matched.
            # NOTE(review): plain "tor" follows the feature description; if
            # the "/tor" path prefix was intended instead, narrow it here.
            if "tor" in host or "tor" in uri:
                tor += 1

    return [
        onions,     # 38
        consensus,  # 39
        tor,        # 40
    ]

def get_logs(logf):
    '''Read a tab-separated log file and index its rows by their second column.

    logf: path of the log file to read.

    Returns a dict mapping the second column of a row (read as the uid) to
    the full list of that row's columns. When a uid occurs more than once,
    the first row wins.

    Lines starting with "#" are skipped as header/metadata lines; a "#"
    later in a row (for example inside a URI) does not cause the row to be
    dropped. Blank lines and rows with fewer than two columns are skipped
    because they have no uid.
    '''
    logdt = dict()
    with open(logf, "r") as ff:
        for line in ff:
            if line.startswith("#"):
                continue
            fields = line.rstrip().split("\t")
            if len(fields) < 2:
                continue
            # setdefault keeps the first row seen for a uid.
            logdt.setdefault(fields[1], fields)

    return logdt
'''
# Input: 
'dtype': data type (benign/malware)
'fpath': path to PCAP for which host features will be extracted
'torconns': A dictionary of verified Tor connections and their Zeek fields of interest
key: SRCIP_DSTIP_SRCPORT_DSTPORT, value: list of top 3 highly active Tor connection fields
value fields: [srcip, dstip, sport, dport, dur, sent bytes, received bytes, sent packets, received packets, connstate, ts]
Sample dictionary provided for reference.
'''
def host_features(dtype, fpath="your.pcap", torconns={'192.168.242.210_37.187.102.108_49168_443': [['192.168.242.210', '37.187.102.108', '49168','443', '5.108209', '2050', '3931', '11', '12', 'RSTO', '1633006175.569994']],'192.168.242.210_37.157.195.87_49167_443': [['192.168.242.210', '37.157.195.87', '49167', '443', '6.102069', '24436','684230', '184', '496', 'S0', '1633006174.576220']],'192.168.242.210_86.59.21.38_49169_443':[['192.168.242.210', '86.59.21.38', '49169', '443', '5.106664', '2026', '3901', '10', '10', 'RSTO', '1633006175.572083']]}):
    '''Assemble the full 40-entry host feature vector.

    dtype: forwarded to connlogs together with torconns.
    fpath: only used in the progress message printed below.
    torconns: dict of connection records handed to connlogs, dnslogs and
        httplogs; defaults to a three-connection sample. The default dict is
        shared between calls and is not modified here.

    Returns the connlogs, dnslogs and httplogs results concatenated in that
    order; an assertion checks that the combined list has 40 entries.
    '''
    print("Extracting Host features for\n", fpath)
    print("Total tor conns: ",len(torconns))
    # Operands are evaluated left to right, so the extractors run in the
    # order conn -> dns -> http.
    ftlst = connlogs(torconns, dtype) + dnslogs(torconns) + httplogs(torconns)
    assert len(ftlst) == 40
    return ftlst

#host_features()
